{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\softwares\\python\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[-0.3553845286369324,\n",
       " 0.48077651858329773,\n",
       " 0.20319104194641113,\n",
       " 0.5691454410552979,\n",
       " 1.0333210229873657,\n",
       " -0.8313525319099426,\n",
       " 0.3445739150047302,\n",
       " -0.9092084765434265,\n",
       " -0.9920079708099365,\n",
       " 0.2351675182580948,\n",
       " 0.20946961641311646,\n",
       " 0.568382740020752,\n",
       " 0.8029783964157104,\n",
       " 0.017373034730553627,\n",
       " 1.8900421857833862,\n",
       " -0.6497050523757935,\n",
       " 1.0451767444610596,\n",
       " -1.3993682861328125,\n",
       " -1.0495959520339966,\n",
       " 0.8101091980934143,\n",
       " -0.5843707323074341,\n",
       " 0.6960948705673218,\n",
       " -1.121341347694397,\n",
       " -0.7917256951332092,\n",
       " -0.32384055852890015,\n",
       " 0.5696032643318176,\n",
       " -0.5509737730026245,\n",
       " -0.8218489289283752,\n",
       " 1.2479422092437744,\n",
       " 1.5368176698684692,\n",
       " 0.6405949592590332,\n",
       " -0.09939694404602051,\n",
       " -0.4520101547241211,\n",
       " 0.2920151352882385,\n",
       " 0.7910453081130981,\n",
       " -0.8409210443496704,\n",
       " 0.9998199939727783,\n",
       " 0.2853350043296814,\n",
       " -0.051960766315460205,\n",
       " -0.8492034077644348,\n",
       " 0.3049454092979431,\n",
       " -0.1277051866054535,\n",
       " -0.6475430727005005,\n",
       " 1.5440731048583984,\n",
       " 1.2596133947372437,\n",
       " 0.1145726665854454,\n",
       " -0.23241277039051056,\n",
       " 0.09554002434015274,\n",
       " -0.640026867389679,\n",
       " 0.3668650984764099,\n",
       " 0.4167139232158661,\n",
       " 9.323674201965332,\n",
       " 1.0335429906845093,\n",
       " 0.18567852675914764,\n",
       " 0.18941062688827515,\n",
       " 0.2760472297668457,\n",
       " 0.6644406914710999,\n",
       " -0.04668571427464485,\n",
       " 0.6310387253761292,\n",
       " -0.6239911317825317,\n",
       " 0.29827114939689636,\n",
       " -0.5037376880645752,\n",
       " 0.6605460047721863,\n",
       " 0.8312529921531677,\n",
       " 0.2759116590023041,\n",
       " -1.173952341079712,\n",
       " -0.7564942836761475,\n",
       " 0.07686419785022736,\n",
       " -1.6794054508209229,\n",
       " 0.1326194703578949,\n",
       " -0.2742536664009094,\n",
       " -0.21880793571472168,\n",
       " 0.3732129633426666,\n",
       " 0.6268978118896484,\n",
       " -0.32701337337493896,\n",
       " 1.5837972164154053,\n",
       " -0.11388660222291946,\n",
       " 0.49243152141571045,\n",
       " -0.2659205496311188,\n",
       " 0.34145408868789673,\n",
       " 0.3788055181503296,\n",
       " 0.0462348535656929,\n",
       " -0.34037572145462036,\n",
       " -0.08771151304244995,\n",
       " -0.3700566291809082,\n",
       " -0.16888514161109924,\n",
       " -1.166020393371582,\n",
       " -2.906568765640259,\n",
       " 0.8474512696266174,\n",
       " 0.09487336874008179,\n",
       " -0.2469932734966278,\n",
       " -0.48918581008911133,\n",
       " 0.08517653495073318,\n",
       " -0.13387355208396912,\n",
       " 0.8442482948303223,\n",
       " 0.4166167080402374,\n",
       " 0.4570929706096649,\n",
       " 0.236543670296669,\n",
       " 0.3077929615974426,\n",
       " -1.3154476881027222,\n",
       " -0.10159008204936981,\n",
       " 1.5106480121612549,\n",
       " 0.3591769337654114,\n",
       " 0.44267740845680237,\n",
       " -1.447298526763916,\n",
       " -0.40897583961486816,\n",
       " -0.240012526512146,\n",
       " 0.05491483211517334,\n",
       " -0.6508415937423706,\n",
       " -0.05004890263080597,\n",
       " -0.6896663308143616,\n",
       " -0.3250100612640381,\n",
       " 0.20398467779159546,\n",
       " 0.8824139833450317,\n",
       " -0.661309003829956,\n",
       " 0.22558452188968658,\n",
       " 0.2187669277191162,\n",
       " 0.27325594425201416,\n",
       " -2.1757946014404297,\n",
       " -0.49991631507873535,\n",
       " 0.39318469166755676,\n",
       " -0.9431384801864624,\n",
       " -0.3283321261405945,\n",
       " -1.2857781648635864,\n",
       " 0.5690340399742126,\n",
       " 0.3456863760948181,\n",
       " 0.269601434469223,\n",
       " 0.2621394991874695,\n",
       " 0.4348730146884918,\n",
       " -0.3657251298427582,\n",
       " 0.8200868368148804,\n",
       " 0.20787985622882843,\n",
       " -1.3209387063980103,\n",
       " -0.7475671768188477,\n",
       " 1.3142268657684326,\n",
       " 0.25076764822006226,\n",
       " -0.3302152156829834,\n",
       " -0.058092422783374786,\n",
       " -0.9038243889808655,\n",
       " -0.3607662320137024,\n",
       " -1.9044625759124756,\n",
       " -0.1048501506447792,\n",
       " -0.41016632318496704,\n",
       " 0.12212547659873962,\n",
       " 1.8709684610366821,\n",
       " 0.31169214844703674,\n",
       " 0.16428697109222412,\n",
       " -0.337072491645813,\n",
       " 0.748779833316803,\n",
       " -0.667922854423523,\n",
       " -0.9939472079277039,\n",
       " 0.2945834994316101,\n",
       " -0.3381521701812744,\n",
       " -0.9633947014808655,\n",
       " 0.04562638700008392,\n",
       " -1.5160499811172485,\n",
       " -0.11709229648113251,\n",
       " -0.5494530200958252,\n",
       " -0.12802891433238983,\n",
       " -0.5134710669517517,\n",
       " 0.40316271781921387,\n",
       " -0.0027923285961151123,\n",
       " 0.2523823380470276,\n",
       " -0.3106342554092407,\n",
       " -0.5982236266136169,\n",
       " -0.5562742948532104,\n",
       " 0.4363545775413513,\n",
       " 0.2277623414993286,\n",
       " 0.014365419745445251,\n",
       " 0.9567214846611023,\n",
       " -0.17476268112659454,\n",
       " 0.8630161285400391,\n",
       " -0.22958536446094513,\n",
       " 0.384452760219574,\n",
       " -0.18489745259284973,\n",
       " -0.5471782684326172,\n",
       " -1.0990999937057495,\n",
       " -0.17785325646400452,\n",
       " -1.1137510538101196,\n",
       " -0.4742060601711273,\n",
       " -0.26404765248298645,\n",
       " -0.2508077621459961,\n",
       " 0.7954322695732117,\n",
       " 0.19495368003845215,\n",
       " -1.2066025733947754,\n",
       " -0.6367514133453369,\n",
       " 0.20824065804481506,\n",
       " 0.5324366688728333,\n",
       " 0.3303869068622589,\n",
       " -0.04256126284599304,\n",
       " 0.05535763502120972,\n",
       " 0.4966875910758972,\n",
       " -3.0370590686798096,\n",
       " -0.15918244421482086,\n",
       " 0.9588220715522766,\n",
       " -0.022243447601795197,\n",
       " -0.24539262056350708,\n",
       " 0.28660595417022705,\n",
       " -0.7771015167236328,\n",
       " 0.5533612966537476,\n",
       " -0.12979058921337128,\n",
       " -0.8813213109970093,\n",
       " -0.1998516470193863,\n",
       " -0.17241635918617249,\n",
       " 0.10872595757246017,\n",
       " 0.06301712989807129,\n",
       " -0.330199658870697,\n",
       " 0.051028117537498474,\n",
       " 1.4092596769332886,\n",
       " 0.14378003776073456,\n",
       " -0.5994438529014587,\n",
       " -0.0631645917892456,\n",
       " -0.11973266303539276,\n",
       " 0.3327415883541107,\n",
       " 0.4168010652065277,\n",
       " 0.7578295469284058,\n",
       " 1.140609622001648,\n",
       " 0.16729629039764404,\n",
       " -0.4578864574432373,\n",
       " -0.6098126769065857,\n",
       " -0.02956642583012581,\n",
       " -0.21491742134094238,\n",
       " 0.03041325882077217,\n",
       " -0.470100075006485,\n",
       " -0.4646453857421875,\n",
       " 0.21842491626739502,\n",
       " 0.3790629804134369,\n",
       " 1.2161989212036133,\n",
       " -0.17067834734916687,\n",
       " -0.8734812140464783,\n",
       " 0.9440299868583679,\n",
       " -0.734499454498291,\n",
       " 0.08977989852428436,\n",
       " -0.7797722816467285,\n",
       " -0.42483335733413696,\n",
       " 1.3135936260223389,\n",
       " 0.9156976342201233,\n",
       " 1.070576548576355,\n",
       " 1.586364984512329,\n",
       " 0.19815421104431152,\n",
       " -0.21513767540454865,\n",
       " -0.10850861668586731,\n",
       " -0.4088248014450073,\n",
       " 0.02095058560371399,\n",
       " 0.29518356919288635,\n",
       " 0.9985461235046387,\n",
       " -0.40797847509384155,\n",
       " 1.3744020462036133,\n",
       " 0.5030481219291687,\n",
       " 0.7443089485168457,\n",
       " 0.7530710697174072,\n",
       " -1.2238355875015259,\n",
       " -0.14415033161640167,\n",
       " -1.1821070909500122,\n",
       " -0.6069613695144653,\n",
       " 0.042383335530757904,\n",
       " -0.24044936895370483,\n",
       " 0.5939999222755432,\n",
       " -0.6616140604019165,\n",
       " -0.725113034248352,\n",
       " 0.7450312376022339,\n",
       " -0.23663292825222015,\n",
       " 0.8649836778640747,\n",
       " 0.6657035946846008,\n",
       " 1.0658891201019287,\n",
       " 0.02483411133289337,\n",
       " 1.202069878578186,\n",
       " -0.3192175030708313,\n",
       " -0.05652960389852524,\n",
       " 0.09059248119592667,\n",
       " -0.2226908653974533,\n",
       " 0.14184322953224182,\n",
       " 0.4552706778049469,\n",
       " -0.20321650803089142,\n",
       " -1.6374750137329102,\n",
       " -0.7225655317306519,\n",
       " -0.38989174365997314,\n",
       " -0.13343602418899536,\n",
       " 1.6604522466659546,\n",
       " -0.6275085210800171,\n",
       " -0.4439527988433838,\n",
       " -1.0512635707855225,\n",
       " -0.10128361731767654,\n",
       " -0.5188666582107544,\n",
       " -0.7285888195037842,\n",
       " 0.2247726321220398,\n",
       " -0.09172359853982925,\n",
       " 0.5065785646438599,\n",
       " -0.2505120635032654,\n",
       " -0.5437736511230469,\n",
       " -1.2209360599517822,\n",
       " -0.16905809938907623,\n",
       " 1.2905056476593018,\n",
       " 0.1474626660346985,\n",
       " 1.5196324586868286,\n",
       " -0.1586109846830368,\n",
       " 0.3907361924648285,\n",
       " 0.530063271522522,\n",
       " 0.6744343638420105,\n",
       " 0.12537330389022827,\n",
       " -0.2166387140750885,\n",
       " -0.28755122423171997,\n",
       " -0.6846142411231995,\n",
       " 0.444118857383728,\n",
       " -0.18123029172420502,\n",
       " 0.36764660477638245,\n",
       " 0.4897822439670563,\n",
       " -0.29414427280426025,\n",
       " -0.4700987935066223,\n",
       " -0.9001128077507019,\n",
       " -0.13895396888256073,\n",
       " 0.18356889486312866,\n",
       " 0.5832914710044861,\n",
       " 0.5994372367858887,\n",
       " -1.204118251800537,\n",
       " -0.279898077249527,\n",
       " -0.15266762673854828,\n",
       " -0.7300839424133301,\n",
       " -0.7137472629547119,\n",
       " -0.1444990187883377,\n",
       " 0.13616178929805756,\n",
       " 1.4045747518539429,\n",
       " 0.9979355335235596,\n",
       " 0.5395926833152771,\n",
       " -0.630727231502533,\n",
       " -0.13742497563362122,\n",
       " 0.02875729836523533,\n",
       " 0.0925893783569336,\n",
       " -0.5604878664016724,\n",
       " 0.6567577719688416,\n",
       " 0.15880709886550903,\n",
       " 0.18433977663516998,\n",
       " 0.39860254526138306,\n",
       " -0.658923864364624,\n",
       " 0.42095112800598145,\n",
       " -0.6391074061393738,\n",
       " -0.9967461228370667,\n",
       " 0.5490021109580994,\n",
       " -1.0840919017791748,\n",
       " -0.6191527843475342,\n",
       " 0.6223672032356262,\n",
       " -0.6125235557556152,\n",
       " 0.046363651752471924,\n",
       " -0.2927088737487793,\n",
       " 0.5111204981803894,\n",
       " 0.5893677473068237,\n",
       " 0.07916495949029922,\n",
       " -0.049704380333423615,\n",
       " -0.463703989982605,\n",
       " 0.01592196524143219,\n",
       " -0.07617993652820587,\n",
       " -0.8885440826416016,\n",
       " 0.16271759569644928,\n",
       " 0.49966341257095337,\n",
       " -1.425042748451233,\n",
       " 1.350191593170166,\n",
       " -0.3482429087162018,\n",
       " 0.4390166401863098,\n",
       " 1.4691810607910156,\n",
       " -0.505117654800415,\n",
       " -0.02207024395465851,\n",
       " 0.2435045838356018,\n",
       " -0.27858424186706543,\n",
       " -1.2800487279891968,\n",
       " 0.13400837779045105,\n",
       " 0.11203321814537048,\n",
       " 0.029184848070144653,\n",
       " -0.14559368789196014,\n",
       " 0.8298282027244568,\n",
       " -0.1848202347755432,\n",
       " -0.3247588574886322,\n",
       " -0.5374534130096436,\n",
       " 0.7707842588424683,\n",
       " -2.034930467605591,\n",
       " -0.21449533104896545,\n",
       " 0.13922002911567688,\n",
       " 0.5783002972602844,\n",
       " 0.8369722962379456,\n",
       " 1.0840709209442139,\n",
       " -0.3214004337787628,\n",
       " -0.4447594881057739,\n",
       " 0.08200018107891083,\n",
       " 0.03591829910874367,\n",
       " -4.963113307952881,\n",
       " -1.271289587020874,\n",
       " 0.5557704567909241,\n",
       " -0.07848861813545227,\n",
       " -0.06229398399591446,\n",
       " 0.48098981380462646,\n",
       " -0.2350502908229828,\n",
       " 0.01635640114545822,\n",
       " 0.3788071870803833,\n",
       " -1.1280704736709595,\n",
       " -1.5513213872909546,\n",
       " 0.25959354639053345,\n",
       " -0.03401295840740204,\n",
       " -0.12645015120506287,\n",
       " 0.33434611558914185,\n",
       " 0.39239126443862915,\n",
       " -0.5666309595108032,\n",
       " 0.8780345916748047,\n",
       " 0.29495272040367126,\n",
       " 1.1588544845581055,\n",
       " -0.1235477551817894,\n",
       " -0.388215035200119,\n",
       " -1.1465522050857544,\n",
       " 0.7389386892318726,\n",
       " 0.6799932718276978,\n",
       " 0.337589830160141,\n",
       " -0.45959335565567017,\n",
       " -0.5274007320404053,\n",
       " 0.9517635107040405,\n",
       " 1.4802467823028564,\n",
       " -0.29663389921188354,\n",
       " -0.07074426114559174,\n",
       " 0.8924921751022339,\n",
       " -0.3509199619293213,\n",
       " -0.24480478465557098,\n",
       " 0.271799236536026,\n",
       " 0.035043299198150635,\n",
       " 0.11895253509283066,\n",
       " 0.48006778955459595,\n",
       " -1.6401340961456299,\n",
       " -0.16739264130592346,\n",
       " 0.6386302709579468,\n",
       " -1.2050458192825317,\n",
       " -0.3324015140533447,\n",
       " 0.02066326141357422,\n",
       " -0.4576408863067627,\n",
       " 0.001304909586906433,\n",
       " -0.35738539695739746,\n",
       " 0.27998411655426025,\n",
       " -0.39879533648490906,\n",
       " 1.2391215562820435,\n",
       " 0.2551766335964203,\n",
       " 0.8251870274543762,\n",
       " -0.2573036253452301,\n",
       " 0.17891621589660645,\n",
       " -0.4641304016113281,\n",
       " 0.5737677216529846,\n",
       " -0.27494651079177856,\n",
       " -1.2752572298049927,\n",
       " 0.05793040245771408,\n",
       " -0.6983433961868286,\n",
       " 1.3558775186538696,\n",
       " 0.8139014840126038,\n",
       " -0.3222547471523285,\n",
       " -0.5001085996627808,\n",
       " -0.19652096927165985,\n",
       " -0.5866294503211975,\n",
       " 0.07530062645673752,\n",
       " 0.1883566826581955,\n",
       " -0.6421807408332825,\n",
       " -0.08520635217428207,\n",
       " -0.11088652908802032,\n",
       " 0.8563172817230225,\n",
       " 0.31306129693984985,\n",
       " 0.18257609009742737,\n",
       " -1.4595413208007812,\n",
       " 1.7287943363189697,\n",
       " 0.6977348327636719,\n",
       " -0.24088339507579803,\n",
       " -0.20503896474838257,\n",
       " 0.11586253345012665,\n",
       " 0.5273114442825317,\n",
       " 0.6501744985580444,\n",
       " -0.010896433144807816,\n",
       " -0.6052039861679077,\n",
       " 0.8992083072662354,\n",
       " 0.5658880472183228,\n",
       " -0.8219771385192871,\n",
       " 1.3649823665618896,\n",
       " 1.169805884361267,\n",
       " -0.2748740315437317,\n",
       " 0.007704608142375946,\n",
       " -0.6281330585479736,\n",
       " 0.036086492240428925,\n",
       " -0.14534039795398712,\n",
       " -0.8053073883056641,\n",
       " -1.4780112504959106,\n",
       " 0.6261004209518433,\n",
       " -0.1535886973142624,\n",
       " 3.138734817504883,\n",
       " -0.18150851130485535,\n",
       " 0.7942280769348145,\n",
       " -0.3300071358680725,\n",
       " 0.06331518292427063,\n",
       " -0.40490832924842834,\n",
       " 0.739654004573822,\n",
       " -0.2109268754720688,\n",
       " -1.312638521194458,\n",
       " -0.04866161569952965,\n",
       " 0.21773704886436462,\n",
       " -0.8945596814155579,\n",
       " -0.6975258588790894,\n",
       " -0.29758596420288086,\n",
       " 0.622664213180542,\n",
       " 0.7264553308486938,\n",
       " 0.5403602123260498,\n",
       " 0.30341899394989014,\n",
       " -0.4647802412509918,\n",
       " 0.013007983565330505,\n",
       " 0.7188246846199036,\n",
       " -0.6101809740066528,\n",
       " -0.32536736130714417,\n",
       " -0.4207545220851898,\n",
       " -0.04891563206911087,\n",
       " -0.37479543685913086,\n",
       " -0.5295923948287964,\n",
       " -0.18404872715473175,\n",
       " -0.10606999695301056,\n",
       " -0.3718582093715668,\n",
       " -0.5333978533744812,\n",
       " -0.0714239776134491,\n",
       " 0.47699058055877686,\n",
       " 0.2406257688999176,\n",
       " -0.8900991678237915,\n",
       " 0.259172260761261,\n",
       " 0.15600469708442688,\n",
       " -0.11649131774902344,\n",
       " -0.9190261363983154,\n",
       " 0.393230676651001,\n",
       " 0.018305130302906036,\n",
       " -0.2400856763124466,\n",
       " -0.015356818214058876,\n",
       " -1.1202309131622314,\n",
       " 0.009183943271636963,\n",
       " -0.36865928769111633,\n",
       " 0.0912545770406723,\n",
       " 0.11445079743862152,\n",
       " -0.02469194307923317,\n",
       " -0.6311827898025513,\n",
       " -0.2528098225593567,\n",
       " -0.6022382378578186,\n",
       " 0.3728804588317871,\n",
       " 1.015621542930603,\n",
       " 1.673832654953003,\n",
       " -0.9932069778442383,\n",
       " -0.023778460919857025,\n",
       " -0.09745818376541138,\n",
       " 0.3937370777130127,\n",
       " 0.23129189014434814,\n",
       " 0.0722389668226242,\n",
       " -0.6996140480041504,\n",
       " -0.5496459603309631,\n",
       " -0.19679689407348633,\n",
       " -0.5857641100883484,\n",
       " -0.4490688145160675,\n",
       " -0.34892794489860535,\n",
       " 1.2350356578826904,\n",
       " -0.8410833477973938,\n",
       " 0.6379339694976807,\n",
       " -1.0814365148544312,\n",
       " 0.08338794112205505,\n",
       " -0.17947348952293396,\n",
       " 0.1772661954164505,\n",
       " -0.8063642382621765,\n",
       " -0.3828698694705963,\n",
       " 0.11873690783977509,\n",
       " -0.4265579581260681,\n",
       " 0.34920448064804077,\n",
       " -0.1913186013698578,\n",
       " 0.6905074119567871,\n",
       " 0.616620659828186,\n",
       " -0.27454811334609985,\n",
       " 0.1024179458618164,\n",
       " -0.5070933103561401,\n",
       " -0.08105158060789108,\n",
       " -1.0422515869140625,\n",
       " -1.1003609895706177,\n",
       " -0.3801124393939972,\n",
       " -0.27233558893203735,\n",
       " 0.0955924540758133,\n",
       " 0.41620030999183655,\n",
       " 0.5633823871612549,\n",
       " 0.5005252361297607,\n",
       " 0.06498408317565918,\n",
       " 0.44028112292289734,\n",
       " -0.3917606472969055,\n",
       " 0.1493881642818451,\n",
       " 1.5928112268447876,\n",
       " 0.3763121962547302,\n",
       " -0.4503709673881531,\n",
       " -0.38331708312034607,\n",
       " -0.10168121755123138,\n",
       " 0.44912707805633545,\n",
       " 0.9801600575447083,\n",
       " 0.4143023192882538,\n",
       " -0.06020985543727875,\n",
       " 0.14566552639007568,\n",
       " 0.12987308204174042,\n",
       " 0.3414257764816284,\n",
       " -0.646064043045044,\n",
       " -0.08601979911327362,\n",
       " -0.24407890439033508,\n",
       " 0.44008669257164,\n",
       " 0.9036880135536194,\n",
       " -0.024904891848564148,\n",
       " -0.1689210832118988,\n",
       " 0.6222386360168457,\n",
       " -1.0070306062698364,\n",
       " -0.052881911396980286,\n",
       " -0.5809668302536011,\n",
       " 0.0913909375667572,\n",
       " 0.9777870178222656,\n",
       " -0.8195319771766663,\n",
       " 0.7784957885742188,\n",
       " -0.12171696126461029,\n",
       " -0.5455284118652344,\n",
       " 0.9356669783592224,\n",
       " 0.20185357332229614,\n",
       " 0.7055404186248779,\n",
       " 0.4768453538417816,\n",
       " -0.5408782958984375,\n",
       " 0.47605398297309875,\n",
       " 1.203736662864685,\n",
       " -1.3356297016143799,\n",
       " -1.149970293045044,\n",
       " 0.5933078527450562,\n",
       " -0.7541511058807373,\n",
       " -0.010523848235607147,\n",
       " -0.6635218262672424,\n",
       " -0.2689937949180603,\n",
       " -0.2305022031068802,\n",
       " -0.6719322800636292,\n",
       " 0.6919701099395752,\n",
       " 1.2547575235366821,\n",
       " -0.597781777381897,\n",
       " 0.48276934027671814,\n",
       " 0.780577540397644,\n",
       " -0.7488816976547241,\n",
       " -0.35157889127731323,\n",
       " -0.21298976242542267,\n",
       " -0.3439244031906128,\n",
       " -0.1445944756269455,\n",
       " -0.2797071933746338,\n",
       " 0.15932752192020416,\n",
       " -0.03392411023378372,\n",
       " -0.15150393545627594,\n",
       " -0.3445291519165039,\n",
       " 0.9656757116317749,\n",
       " 0.11015340685844421,\n",
       " 0.1851874440908432,\n",
       " 0.19162926077842712,\n",
       " -0.7223946452140808,\n",
       " -0.6031566262245178,\n",
       " 0.6567255258560181,\n",
       " -0.4083419442176819,\n",
       " 1.68907630443573,\n",
       " -0.08843859285116196,\n",
       " 0.5298622250556946,\n",
       " 0.45894065499305725,\n",
       " 0.09147440642118454,\n",
       " 0.12044152617454529,\n",
       " -0.23471368849277496,\n",
       " -0.23757262527942657,\n",
       " 0.746601402759552,\n",
       " 0.36828041076660156,\n",
       " -0.0683094710111618,\n",
       " -0.32476356625556946,\n",
       " -0.29789984226226807,\n",
       " -0.4242316484451294,\n",
       " 1.0512315034866333,\n",
       " 1.2462561130523682,\n",
       " 0.4928674101829529,\n",
       " 0.33101338148117065,\n",
       " -0.6779442429542542,\n",
       " -0.0020301640033721924,\n",
       " -0.3759559690952301,\n",
       " -0.8933094143867493,\n",
       " -0.20078402757644653,\n",
       " 0.13268910348415375,\n",
       " 0.0036692507565021515,\n",
       " -1.1012006998062134,\n",
       " -0.5505877733230591,\n",
       " -0.6690686345100403,\n",
       " -0.35175764560699463,\n",
       " -0.3531993627548218,\n",
       " -1.1194424629211426,\n",
       " 0.23173613846302032,\n",
       " 0.5721845626831055,\n",
       " -1.6449793577194214,\n",
       " -0.11712078005075455,\n",
       " -0.28994178771972656,\n",
       " -0.3343273997306824,\n",
       " -0.11961323022842407,\n",
       " -0.025194713845849037,\n",
       " -0.583025336265564,\n",
       " -0.663966715335846,\n",
       " -0.17947226762771606,\n",
       " -0.14656125009059906,\n",
       " 0.5229567289352417,\n",
       " 0.7922115921974182,\n",
       " 0.14173048734664917,\n",
       " -0.45060473680496216,\n",
       " -0.6296430826187134,\n",
       " 0.21450892090797424,\n",
       " 0.5511866211891174,\n",
       " -0.597366988658905,\n",
       " -0.9478820562362671,\n",
       " -0.3126716911792755,\n",
       " 0.36880016326904297,\n",
       " 0.2522083520889282,\n",
       " -0.3883894085884094,\n",
       " -1.6488450765609741,\n",
       " -0.18424773216247559,\n",
       " 0.10374307632446289,\n",
       " 0.623898446559906,\n",
       " -0.004264481365680695,\n",
       " 0.38334572315216064,\n",
       " -0.7933332920074463,\n",
       " 0.28710657358169556,\n",
       " 0.8277014493942261,\n",
       " -1.2192846536636353,\n",
       " 0.02312173694372177,\n",
       " 0.6936168074607849,\n",
       " -0.6828246712684631,\n",
       " -2.264225482940674,\n",
       " -0.5131387114524841,\n",
       " 0.5179283618927002,\n",
       " 0.5750632286071777,\n",
       " -0.5841541290283203,\n",
       " 0.8966342806816101,\n",
       " -0.04733448103070259,\n",
       " -0.2015240490436554,\n",
       " -0.2091940939426422,\n",
       " -0.45728468894958496,\n",
       " 0.20280759036540985,\n",
       " 0.2337988168001175,\n",
       " 0.14312338829040527,\n",
       " 0.7321423292160034,\n",
       " -0.7836079597473145,\n",
       " 0.003932744264602661,\n",
       " 0.7416870594024658,\n",
       " 0.3520488739013672,\n",
       " 0.6561542749404907,\n",
       " 3.5174248218536377,\n",
       " -0.11180109530687332,\n",
       " 0.2703627943992615,\n",
       " -0.31804990768432617,\n",
       " 0.2024000585079193,\n",
       " 0.4019613265991211,\n",
       " -1.0677554607391357,\n",
       " 0.09833987802267075,\n",
       " 0.09336930513381958,\n",
       " 0.11544661223888397,\n",
       " 0.49970608949661255,\n",
       " 0.10166844725608826,\n",
       " 0.4005911350250244,\n",
       " 0.2228279560804367,\n",
       " -0.649901807308197,\n",
       " 0.6568697094917297,\n",
       " 2.275555372238159,\n",
       " 0.5175516605377197,\n",
       " 0.7804374098777771,\n",
       " 0.08831167221069336,\n",
       " -0.58637535572052,\n",
       " -1.1253881454467773,\n",
       " 0.17490118741989136,\n",
       " -0.24460077285766602,\n",
       " 0.28293177485466003,\n",
       " -0.7224493026733398,\n",
       " -0.01742905005812645,\n",
       " -0.027672255411744118,\n",
       " 0.19988037645816803,\n",
       " -0.4851788580417633,\n",
       " 0.369351327419281,\n",
       " -0.35490748286247253]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers.pipelines import pipeline\n",
    "\n",
    "# Build a feature-extraction pipeline backed by the bert-base-chinese checkpoint.\n",
    "embedding_model = pipeline(task=\"feature-extraction\", model=\"bert-base-chinese\")\n",
    "\n",
    "# Embed a single sentence; the result is presumably a nested list indexed as\n",
    "# [input][token][hidden unit] -- confirm against the pipeline documentation.\n",
    "embs = embedding_model(\"今天天气很好\")\n",
    "\n",
    "# Display the vector at the first token position of the first input\n",
    "# (for BERT tokenizers this should be the [CLS] slot).\n",
    "embs[0][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用自己写的代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 ['今天天气很好']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00, 25.99it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([-3.55384529e-01,  4.80776519e-01,  2.03191042e-01,  5.69145441e-01,\n",
       "        1.03332102e+00, -8.31352532e-01,  3.44573915e-01, -9.09208477e-01,\n",
       "       -9.92007971e-01,  2.35167518e-01,  2.09469616e-01,  5.68382740e-01,\n",
       "        8.02978396e-01,  1.73730347e-02,  1.89004219e+00, -6.49705052e-01,\n",
       "        1.04517674e+00, -1.39936829e+00, -1.04959595e+00,  8.10109198e-01,\n",
       "       -5.84370732e-01,  6.96094871e-01, -1.12134135e+00, -7.91725695e-01,\n",
       "       -3.23840559e-01,  5.69603264e-01, -5.50973773e-01, -8.21848929e-01,\n",
       "        1.24794221e+00,  1.53681767e+00,  6.40594959e-01, -9.93969440e-02,\n",
       "       -4.52010155e-01,  2.92015135e-01,  7.91045308e-01, -8.40921044e-01,\n",
       "        9.99819994e-01,  2.85335004e-01, -5.19607663e-02, -8.49203408e-01,\n",
       "        3.04945409e-01, -1.27705187e-01, -6.47543073e-01,  1.54407310e+00,\n",
       "        1.25961339e+00,  1.14572667e-01, -2.32412770e-01,  9.55400243e-02,\n",
       "       -6.40026867e-01,  3.66865098e-01,  4.16713923e-01,  9.32367420e+00,\n",
       "        1.03354299e+00,  1.85678527e-01,  1.89410627e-01,  2.76047230e-01,\n",
       "        6.64440691e-01, -4.66857143e-02,  6.31038725e-01, -6.23991132e-01,\n",
       "        2.98271149e-01, -5.03737688e-01,  6.60546005e-01,  8.31252992e-01,\n",
       "        2.75911659e-01, -1.17395234e+00, -7.56494284e-01,  7.68641979e-02,\n",
       "       -1.67940545e+00,  1.32619470e-01, -2.74253666e-01, -2.18807936e-01,\n",
       "        3.73212963e-01,  6.26897812e-01, -3.27013373e-01,  1.58379722e+00,\n",
       "       -1.13886602e-01,  4.92431521e-01, -2.65920550e-01,  3.41454089e-01,\n",
       "        3.78805518e-01,  4.62348536e-02, -3.40375721e-01, -8.77115130e-02,\n",
       "       -3.70056629e-01, -1.68885142e-01, -1.16602039e+00, -2.90656877e+00,\n",
       "        8.47451270e-01,  9.48733687e-02, -2.46993273e-01, -4.89185810e-01,\n",
       "        8.51765350e-02, -1.33873552e-01,  8.44248295e-01,  4.16616708e-01,\n",
       "        4.57092971e-01,  2.36543670e-01,  3.07792962e-01, -1.31544769e+00,\n",
       "       -1.01590082e-01,  1.51064801e+00,  3.59176934e-01,  4.42677408e-01,\n",
       "       -1.44729853e+00, -4.08975840e-01, -2.40012527e-01,  5.49148321e-02,\n",
       "       -6.50841594e-01, -5.00489026e-02, -6.89666331e-01, -3.25010061e-01,\n",
       "        2.03984678e-01,  8.82413983e-01, -6.61309004e-01,  2.25584522e-01,\n",
       "        2.18766928e-01,  2.73255944e-01, -2.17579460e+00, -4.99916315e-01,\n",
       "        3.93184692e-01, -9.43138480e-01, -3.28332126e-01, -1.28577816e+00,\n",
       "        5.69034040e-01,  3.45686376e-01,  2.69601434e-01,  2.62139499e-01,\n",
       "        4.34873015e-01, -3.65725130e-01,  8.20086837e-01,  2.07879856e-01,\n",
       "       -1.32093871e+00, -7.47567177e-01,  1.31422687e+00,  2.50767648e-01,\n",
       "       -3.30215216e-01, -5.80924228e-02, -9.03824389e-01, -3.60766232e-01,\n",
       "       -1.90446258e+00, -1.04850151e-01, -4.10166323e-01,  1.22125477e-01,\n",
       "        1.87096846e+00,  3.11692148e-01,  1.64286971e-01, -3.37072492e-01,\n",
       "        7.48779833e-01, -6.67922854e-01, -9.93947208e-01,  2.94583499e-01,\n",
       "       -3.38152170e-01, -9.63394701e-01,  4.56263870e-02, -1.51604998e+00,\n",
       "       -1.17092296e-01, -5.49453020e-01, -1.28028914e-01, -5.13471067e-01,\n",
       "        4.03162718e-01, -2.79232860e-03,  2.52382338e-01, -3.10634255e-01,\n",
       "       -5.98223627e-01, -5.56274295e-01,  4.36354578e-01,  2.27762341e-01,\n",
       "        1.43654197e-02,  9.56721485e-01, -1.74762681e-01,  8.63016129e-01,\n",
       "       -2.29585364e-01,  3.84452760e-01, -1.84897453e-01, -5.47178268e-01,\n",
       "       -1.09909999e+00, -1.77853256e-01, -1.11375105e+00, -4.74206060e-01,\n",
       "       -2.64047652e-01, -2.50807762e-01,  7.95432270e-01,  1.94953680e-01,\n",
       "       -1.20660257e+00, -6.36751413e-01,  2.08240658e-01,  5.32436669e-01,\n",
       "        3.30386907e-01, -4.25612628e-02,  5.53576350e-02,  4.96687591e-01,\n",
       "       -3.03705907e+00, -1.59182444e-01,  9.58822072e-01, -2.22434476e-02,\n",
       "       -2.45392621e-01,  2.86605954e-01, -7.77101517e-01,  5.53361297e-01,\n",
       "       -1.29790589e-01, -8.81321311e-01, -1.99851647e-01, -1.72416359e-01,\n",
       "        1.08725958e-01,  6.30171299e-02, -3.30199659e-01,  5.10281175e-02,\n",
       "        1.40925968e+00,  1.43780038e-01, -5.99443853e-01, -6.31645918e-02,\n",
       "       -1.19732663e-01,  3.32741588e-01,  4.16801065e-01,  7.57829547e-01,\n",
       "        1.14060962e+00,  1.67296290e-01, -4.57886457e-01, -6.09812677e-01,\n",
       "       -2.95664258e-02, -2.14917421e-01,  3.04132588e-02, -4.70100075e-01,\n",
       "       -4.64645386e-01,  2.18424916e-01,  3.79062980e-01,  1.21619892e+00,\n",
       "       -1.70678347e-01, -8.73481214e-01,  9.44029987e-01, -7.34499454e-01,\n",
       "        8.97798985e-02, -7.79772282e-01, -4.24833357e-01,  1.31359363e+00,\n",
       "        9.15697634e-01,  1.07057655e+00,  1.58636498e+00,  1.98154211e-01,\n",
       "       -2.15137675e-01, -1.08508617e-01, -4.08824801e-01,  2.09505856e-02,\n",
       "        2.95183569e-01,  9.98546124e-01, -4.07978475e-01,  1.37440205e+00,\n",
       "        5.03048122e-01,  7.44308949e-01,  7.53071070e-01, -1.22383559e+00,\n",
       "       -1.44150332e-01, -1.18210709e+00, -6.06961370e-01,  4.23833355e-02,\n",
       "       -2.40449369e-01,  5.93999922e-01, -6.61614060e-01, -7.25113034e-01,\n",
       "        7.45031238e-01, -2.36632928e-01,  8.64983678e-01,  6.65703595e-01,\n",
       "        1.06588912e+00,  2.48341113e-02,  1.20206988e+00, -3.19217503e-01,\n",
       "       -5.65296039e-02,  9.05924812e-02, -2.22690865e-01,  1.41843230e-01,\n",
       "        4.55270678e-01, -2.03216508e-01, -1.63747501e+00, -7.22565532e-01,\n",
       "       -3.89891744e-01, -1.33436024e-01,  1.66045225e+00, -6.27508521e-01,\n",
       "       -4.43952799e-01, -1.05126357e+00, -1.01283617e-01, -5.18866658e-01,\n",
       "       -7.28588820e-01,  2.24772632e-01, -9.17235985e-02,  5.06578565e-01,\n",
       "       -2.50512064e-01, -5.43773651e-01, -1.22093606e+00, -1.69058099e-01,\n",
       "        1.29050565e+00,  1.47462666e-01,  1.51963246e+00, -1.58610985e-01,\n",
       "        3.90736192e-01,  5.30063272e-01,  6.74434364e-01,  1.25373304e-01,\n",
       "       -2.16638714e-01, -2.87551224e-01, -6.84614241e-01,  4.44118857e-01,\n",
       "       -1.81230292e-01,  3.67646605e-01,  4.89782244e-01, -2.94144273e-01,\n",
       "       -4.70098794e-01, -9.00112808e-01, -1.38953969e-01,  1.83568895e-01,\n",
       "        5.83291471e-01,  5.99437237e-01, -1.20411825e+00, -2.79898077e-01,\n",
       "       -1.52667627e-01, -7.30083942e-01, -7.13747263e-01, -1.44499019e-01,\n",
       "        1.36161789e-01,  1.40457475e+00,  9.97935534e-01,  5.39592683e-01,\n",
       "       -6.30727232e-01, -1.37424976e-01,  2.87572984e-02,  9.25893784e-02,\n",
       "       -5.60487866e-01,  6.56757772e-01,  1.58807099e-01,  1.84339777e-01,\n",
       "        3.98602545e-01, -6.58923864e-01,  4.20951128e-01, -6.39107406e-01,\n",
       "       -9.96746123e-01,  5.49002111e-01, -1.08409190e+00, -6.19152784e-01,\n",
       "        6.22367203e-01, -6.12523556e-01,  4.63636518e-02, -2.92708874e-01,\n",
       "        5.11120498e-01,  5.89367747e-01,  7.91649595e-02, -4.97043803e-02,\n",
       "       -4.63703990e-01,  1.59219652e-02, -7.61799365e-02, -8.88544083e-01,\n",
       "        1.62717596e-01,  4.99663413e-01, -1.42504275e+00,  1.35019159e+00,\n",
       "       -3.48242909e-01,  4.39016640e-01,  1.46918106e+00, -5.05117655e-01,\n",
       "       -2.20702440e-02,  2.43504584e-01, -2.78584242e-01, -1.28004873e+00,\n",
       "        1.34008378e-01,  1.12033218e-01,  2.91848481e-02, -1.45593688e-01,\n",
       "        8.29828203e-01, -1.84820235e-01, -3.24758857e-01, -5.37453413e-01,\n",
       "        7.70784259e-01, -2.03493047e+00, -2.14495331e-01,  1.39220029e-01,\n",
       "        5.78300297e-01,  8.36972296e-01,  1.08407092e+00, -3.21400434e-01,\n",
       "       -4.44759488e-01,  8.20001811e-02,  3.59182991e-02, -4.96311331e+00,\n",
       "       -1.27128959e+00,  5.55770457e-01, -7.84886181e-02, -6.22939840e-02,\n",
       "        4.80989814e-01, -2.35050291e-01,  1.63564011e-02,  3.78807187e-01,\n",
       "       -1.12807047e+00, -1.55132139e+00,  2.59593546e-01, -3.40129584e-02,\n",
       "       -1.26450151e-01,  3.34346116e-01,  3.92391264e-01, -5.66630960e-01,\n",
       "        8.78034592e-01,  2.94952720e-01,  1.15885448e+00, -1.23547755e-01,\n",
       "       -3.88215035e-01, -1.14655221e+00,  7.38938689e-01,  6.79993272e-01,\n",
       "        3.37589830e-01, -4.59593356e-01, -5.27400732e-01,  9.51763511e-01,\n",
       "        1.48024678e+00, -2.96633899e-01, -7.07442611e-02,  8.92492175e-01,\n",
       "       -3.50919962e-01, -2.44804785e-01,  2.71799237e-01,  3.50432992e-02,\n",
       "        1.18952535e-01,  4.80067790e-01, -1.64013410e+00, -1.67392641e-01,\n",
       "        6.38630271e-01, -1.20504582e+00, -3.32401514e-01,  2.06632614e-02,\n",
       "       -4.57640886e-01,  1.30490959e-03, -3.57385397e-01,  2.79984117e-01,\n",
       "       -3.98795336e-01,  1.23912156e+00,  2.55176634e-01,  8.25187027e-01,\n",
       "       -2.57303625e-01,  1.78916216e-01, -4.64130402e-01,  5.73767722e-01,\n",
       "       -2.74946511e-01, -1.27525723e+00,  5.79304025e-02, -6.98343396e-01,\n",
       "        1.35587752e+00,  8.13901484e-01, -3.22254747e-01, -5.00108600e-01,\n",
       "       -1.96520969e-01, -5.86629450e-01,  7.53006265e-02,  1.88356683e-01,\n",
       "       -6.42180741e-01, -8.52063522e-02, -1.10886529e-01,  8.56317282e-01,\n",
       "        3.13061297e-01,  1.82576090e-01, -1.45954132e+00,  1.72879434e+00,\n",
       "        6.97734833e-01, -2.40883395e-01, -2.05038965e-01,  1.15862533e-01,\n",
       "        5.27311444e-01,  6.50174499e-01, -1.08964331e-02, -6.05203986e-01,\n",
       "        8.99208307e-01,  5.65888047e-01, -8.21977139e-01,  1.36498237e+00,\n",
       "        1.16980588e+00, -2.74874032e-01,  7.70460814e-03, -6.28133059e-01,\n",
       "        3.60864922e-02, -1.45340398e-01, -8.05307388e-01, -1.47801125e+00,\n",
       "        6.26100421e-01, -1.53588697e-01,  3.13873482e+00, -1.81508511e-01,\n",
       "        7.94228077e-01, -3.30007136e-01,  6.33151829e-02, -4.04908329e-01,\n",
       "        7.39654005e-01, -2.10926875e-01, -1.31263852e+00, -4.86616157e-02,\n",
       "        2.17737049e-01, -8.94559681e-01, -6.97525859e-01, -2.97585964e-01,\n",
       "        6.22664213e-01,  7.26455331e-01,  5.40360212e-01,  3.03418994e-01,\n",
       "       -4.64780241e-01,  1.30079836e-02,  7.18824685e-01, -6.10180974e-01,\n",
       "       -3.25367361e-01, -4.20754522e-01, -4.89156321e-02, -3.74795437e-01,\n",
       "       -5.29592395e-01, -1.84048727e-01, -1.06069997e-01, -3.71858209e-01,\n",
       "       -5.33397853e-01, -7.14239776e-02,  4.76990581e-01,  2.40625769e-01,\n",
       "       -8.90099168e-01,  2.59172261e-01,  1.56004697e-01, -1.16491318e-01,\n",
       "       -9.19026136e-01,  3.93230677e-01,  1.83051303e-02, -2.40085676e-01,\n",
       "       -1.53568182e-02, -1.12023091e+00,  9.18394327e-03, -3.68659288e-01,\n",
       "        9.12545770e-02,  1.14450797e-01, -2.46919431e-02, -6.31182790e-01,\n",
       "       -2.52809823e-01, -6.02238238e-01,  3.72880459e-01,  1.01562154e+00,\n",
       "        1.67383265e+00, -9.93206978e-01, -2.37784609e-02, -9.74581838e-02,\n",
       "        3.93737078e-01,  2.31291890e-01,  7.22389668e-02, -6.99614048e-01,\n",
       "       -5.49645960e-01, -1.96796894e-01, -5.85764110e-01, -4.49068815e-01,\n",
       "       -3.48927945e-01,  1.23503566e+00, -8.41083348e-01,  6.37933969e-01,\n",
       "       -1.08143651e+00,  8.33879411e-02, -1.79473490e-01,  1.77266195e-01,\n",
       "       -8.06364238e-01, -3.82869869e-01,  1.18736908e-01, -4.26557958e-01,\n",
       "        3.49204481e-01, -1.91318601e-01,  6.90507412e-01,  6.16620660e-01,\n",
       "       -2.74548113e-01,  1.02417946e-01, -5.07093310e-01, -8.10515806e-02,\n",
       "       -1.04225159e+00, -1.10036099e+00, -3.80112439e-01, -2.72335589e-01,\n",
       "        9.55924541e-02,  4.16200310e-01,  5.63382387e-01,  5.00525236e-01,\n",
       "        6.49840832e-02,  4.40281123e-01, -3.91760647e-01,  1.49388164e-01,\n",
       "        1.59281123e+00,  3.76312196e-01, -4.50370967e-01, -3.83317083e-01,\n",
       "       -1.01681218e-01,  4.49127078e-01,  9.80160058e-01,  4.14302319e-01,\n",
       "       -6.02098554e-02,  1.45665526e-01,  1.29873082e-01,  3.41425776e-01,\n",
       "       -6.46064043e-01, -8.60197991e-02, -2.44078904e-01,  4.40086693e-01,\n",
       "        9.03688014e-01, -2.49048918e-02, -1.68921083e-01,  6.22238636e-01,\n",
       "       -1.00703061e+00, -5.28819114e-02, -5.80966830e-01,  9.13909376e-02,\n",
       "        9.77787018e-01, -8.19531977e-01,  7.78495789e-01, -1.21716961e-01,\n",
       "       -5.45528412e-01,  9.35666978e-01,  2.01853573e-01,  7.05540419e-01,\n",
       "        4.76845354e-01, -5.40878296e-01,  4.76053983e-01,  1.20373666e+00,\n",
       "       -1.33562970e+00, -1.14997029e+00,  5.93307853e-01, -7.54151106e-01,\n",
       "       -1.05238482e-02, -6.63521826e-01, -2.68993795e-01, -2.30502203e-01,\n",
       "       -6.71932280e-01,  6.91970110e-01,  1.25475752e+00, -5.97781777e-01,\n",
       "        4.82769340e-01,  7.80577540e-01, -7.48881698e-01, -3.51578891e-01,\n",
       "       -2.12989762e-01, -3.43924403e-01, -1.44594476e-01, -2.79707193e-01,\n",
       "        1.59327522e-01, -3.39241102e-02, -1.51503935e-01, -3.44529152e-01,\n",
       "        9.65675712e-01,  1.10153407e-01,  1.85187444e-01,  1.91629261e-01,\n",
       "       -7.22394645e-01, -6.03156626e-01,  6.56725526e-01, -4.08341944e-01,\n",
       "        1.68907630e+00, -8.84385929e-02,  5.29862225e-01,  4.58940655e-01,\n",
       "        9.14744064e-02,  1.20441526e-01, -2.34713688e-01, -2.37572625e-01,\n",
       "        7.46601403e-01,  3.68280411e-01, -6.83094710e-02, -3.24763566e-01,\n",
       "       -2.97899842e-01, -4.24231648e-01,  1.05123150e+00,  1.24625611e+00,\n",
       "        4.92867410e-01,  3.31013381e-01, -6.77944243e-01, -2.03016400e-03,\n",
       "       -3.75955969e-01, -8.93309414e-01, -2.00784028e-01,  1.32689103e-01,\n",
       "        3.66925076e-03, -1.10120070e+00, -5.50587773e-01, -6.69068635e-01,\n",
       "       -3.51757646e-01, -3.53199363e-01, -1.11944246e+00,  2.31736138e-01,\n",
       "        5.72184563e-01, -1.64497936e+00, -1.17120780e-01, -2.89941788e-01,\n",
       "       -3.34327400e-01, -1.19613230e-01, -2.51947138e-02, -5.83025336e-01,\n",
       "       -6.63966715e-01, -1.79472268e-01, -1.46561250e-01,  5.22956729e-01,\n",
       "        7.92211592e-01,  1.41730487e-01, -4.50604737e-01, -6.29643083e-01,\n",
       "        2.14508921e-01,  5.51186621e-01, -5.97366989e-01, -9.47882056e-01,\n",
       "       -3.12671691e-01,  3.68800163e-01,  2.52208352e-01, -3.88389409e-01,\n",
       "       -1.64884508e+00, -1.84247732e-01,  1.03743076e-01,  6.23898447e-01,\n",
       "       -4.26448137e-03,  3.83345723e-01, -7.93333292e-01,  2.87106574e-01,\n",
       "        8.27701449e-01, -1.21928465e+00,  2.31217369e-02,  6.93616807e-01,\n",
       "       -6.82824671e-01, -2.26422548e+00, -5.13138711e-01,  5.17928362e-01,\n",
       "        5.75063229e-01, -5.84154129e-01,  8.96634281e-01, -4.73344810e-02,\n",
       "       -2.01524049e-01, -2.09194094e-01, -4.57284689e-01,  2.02807590e-01,\n",
       "        2.33798817e-01,  1.43123388e-01,  7.32142329e-01, -7.83607960e-01,\n",
       "        3.93274426e-03,  7.41687059e-01,  3.52048874e-01,  6.56154275e-01,\n",
       "        3.51742482e+00, -1.11801095e-01,  2.70362794e-01, -3.18049908e-01,\n",
       "        2.02400059e-01,  4.01961327e-01, -1.06775546e+00,  9.83398780e-02,\n",
       "        9.33693051e-02,  1.15446612e-01,  4.99706089e-01,  1.01668447e-01,\n",
       "        4.00591135e-01,  2.22827956e-01, -6.49901807e-01,  6.56869709e-01,\n",
       "        2.27555537e+00,  5.17551661e-01,  7.80437410e-01,  8.83116722e-02,\n",
       "       -5.86375356e-01, -1.12538815e+00,  1.74901187e-01, -2.44600773e-01,\n",
       "        2.82931775e-01, -7.22449303e-01, -1.74290501e-02, -2.76722554e-02,\n",
       "        1.99880376e-01, -4.85178858e-01,  3.69351327e-01, -3.54907483e-01],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import BertTokenizer, BertModel\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "# 加载文件\n",
    "sentences = ['今天天气很好']\n",
    "\n",
    "# 准备模型\n",
    "model_name = \"bert-base-chinese\"\n",
    "model = BertModel.from_pretrained(model_name)\n",
    "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# 设置设备\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "model.eval() \n",
    "\n",
    "# 转换为词向量\n",
    "batch_size = 16  # 批大小\n",
    "data_loader = DataLoader(sentences, batch_size=batch_size)\n",
    "for batch in data_loader:\n",
    "    print(len(batch), batch)\n",
    "cls_embeddings = []\n",
    "for batch_sentences in tqdm(data_loader):\n",
    "    inputs = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors=\"pt\", max_length=512)\n",
    "    inputs.to(device)\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "    cls_embeddings.append(outputs.last_hidden_state[:, 0].cpu().numpy()) # 只取CLS对应的向量\n",
    "\n",
    "cls_embeddings = np.vstack(cls_embeddings)\n",
    "cls_embeddings[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array_equal(\n",
    "  np.array(embs[0][0]),\n",
    "  cls_embeddings[0]\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
